Extending the Context Window of LLMs

Large Language Models
Author

Maxime Labonne

Published

July 7, 2023

📝 Article: https://kaiokendev.github.io/context

Problem

It is hard to extend the sequence length of a model beyond the length it was trained on.

  • Anil et al. (2022): length extrapolation fails in part because of “distracting tokens” in the input during the PARITY task.
  • Chi et al. (2022): bias terms in positional encodings (as in ALiBi) replicate the effect of windowed attention by decaying token inter-dependency over long-range receptive fields (each token only attends to the tokens closest to it); a sketch of this effect appears after this list.
  • Tao et al. (2023) observe that, in long sequences, rear position embeddings are updated far less often than front position embeddings. They add random padding to the front part of the sequence.
  • Liu et al. (2023): attention in long sequences starts to drift as we move to later positions and only attends to the most recent tokens.
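
To make the ALiBi point concrete, here is a minimal sketch (my own illustration, not code from the article) of a linear distance bias added to the attention logits: the further a key is from the query, the larger the penalty, so attention collapses onto the nearest tokens, which is the windowed-attention effect Chi et al. describe.

import torch

def alibi_bias(seq_len: int, num_heads: int) -> torch.Tensor:
    """Linear distance penalty added to the attention logits (ALiBi-style)."""
    # Geometrically spaced per-head slopes, as in the ALiBi paper.
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])
    pos = torch.arange(seq_len)
    # (j - i): 0 on the diagonal, increasingly negative for more distant past keys.
    rel = (pos[None, :] - pos[:, None]).float()
    bias = slopes[:, None, None] * rel                  # (num_heads, seq_len, seq_len)
    # Causal mask: future keys are removed from the softmax entirely.
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    return bias.masked_fill(future, float("-inf"))

# Usage (shapes assumed): scores = q @ k.transpose(-2, -1) / head_dim**0.5 + alibi_bias(seq_len, num_heads)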

Silver Linings

The attention mechanism seems to become destabilized on long sequences due to an imbalance in the attended tokens (attention is skewed either to the front or to the back of the sequence).

Several solutions have been proposed:

  • Few-shot chain-of-thought reasoning and marker tokens
  • Length generalization/extrapolation is a learnable ability to a certain extent (it improves performance but is not a silver bullet).
  • LLaMA 7B has been trained for retrieval over a 32K-token window by introducing landmark tokens combined with windowed attention (blockwise computation).

Potential Solutions

  • Change the attention calculation: log(n) scaling (does help; see the sketch after this list), replacing the softmax with ReLU in the attention equation (does not converge), etc.
  • Random Positional Encoding
  • Shifted Positional Encodings: shifting the tokens progressively along the desired length during the encoding step (failure).
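
As an illustration of the log(n) scaling idea, here is a simplified sketch (not the article's code; the function name, the train_len default, and the use of a single per-sequence factor are my assumptions, and some implementations scale per position instead). The pre-softmax logits are multiplied by log(n) / log(train_len), which is 1.0 at the training length and grows slowly beyond it, keeping attention entropy roughly stable on longer inputs.

import math
import torch

def logn_scaled_attention(q, k, v, train_len: int = 2048):
    """Causal attention whose logits are rescaled by log(n) / log(train_len)."""
    seq_len, head_dim = q.shape[-2], q.shape[-1]
    # Standard 1/sqrt(d) scaling plus the log-length extrapolation factor.
    logn = max(1.0, math.log(seq_len) / math.log(train_len))
    scores = (q @ k.transpose(-2, -1)) * logn / math.sqrt(head_dim)
    # Causal mask: future keys never contribute to the softmax.
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v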

Final Solution

Transformers do not learn to gauge position from relative distances or rotational factors; instead, they memorize the tokens together with their positional scaling factors.

  • Rotary positional embeddings can loop the positions around after crossing the max context length (e.g., 2048): position_ids = position_ids % 2048
  • Block repeated positions: repeating the chosen frequency for a block of positions, so [1, 2, 3, 4, 5, …, L] becomes [1, 1, 1, 1, 2, 2, 2, 2, 3, …, L/4]. This is achieved by changing the position update in the rotary embedding: t *= 1/4.

In other words, several tokens (four in this example) are assigned to the same position. This (surprising) scheme can quadruple the context length with minimal performance degradation (~2%). More details in this paper from Meta: https://arxiv.org/abs/2306.15595
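
Here is a minimal sketch of how both position tricks might be wired into a rotary-embedding cache (my own simplified version, not kaiokendev's actual patch; the function name and arguments are assumptions). With scale = 1/4, positions are compressed so that four consecutive tokens fall within one original position step; with loop=True, position ids wrap around the maximum context length.

import torch

def rope_cos_sin(seq_len, dim, base=10000.0, max_pos=2048, scale=1.0, loop=False):
    """cos/sin tables for rotary embeddings with the two position tricks above."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len).float()
    if loop:
        t = t % max_pos          # loop positions around after the max context length
    t = t * scale                # scale = 1/4: four consecutive tokens share one position step
    # (flooring t here would give the exact block-repeated ids [1, 1, 1, 1, 2, ...])
    freqs = torch.outer(t, inv_freq)              # (seq_len, dim // 2)
    return freqs.cos(), freqs.sin()               # used to rotate the q/k pairs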